#-*-coding:utf-8-*-
# date:2021-04-15
# author: likecy
# function : model utils

import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import random




def f1_sorce(y_true:torch.Tensor, y_pred:torch.Tensor, is_training=False) -> torch.Tensor:
    '''Calculate the binary F1 score. Can work with GPU tensors.

    The original implementation is written by Michal Haltuf on Kaggle.

    The counts are computed as ``y_true * y_pred`` / ``(1 - y_true) * y_pred`` /
    ``y_true * (1 - y_pred)``, so labels are assumed to be 0/1 (binary task).

    Parameters
    ----------
    y_true : torch.Tensor
        1-D tensor of ground-truth labels (assumed 0 or 1).
    y_pred : torch.Tensor
        1-D tensor of predictions (assumed to lie in [0, 1]; soft values
        give a differentiable surrogate), or a 2-D tensor of per-class scores
        that is reduced with ``argmax(dim=1)``. The 2-D form is only
        meaningful with two classes, since argmax indices > 1 break the
        0/1 formulas above.
    is_training : bool
        If True, the result participates in autograd (a leaf result is marked
        ``requires_grad``; a result already attached to a graph is returned
        as-is). If False, the result is detached from any graph.

    Returns
    -------
    torch.Tensor
        0-dim float32 tensor, 0 <= val <= 1 for inputs in [0, 1].

    Reference
    ---------
    - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
    - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6

    '''
    assert y_true.ndim == 1
    assert y_pred.ndim == 1 or y_pred.ndim == 2

    if y_pred.ndim == 2:
        # Per-class scores -> hard class indices.
        y_pred = y_pred.argmax(dim=1)

    tp = (y_true * y_pred).sum().to(torch.float32)
    fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
    fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

    # Guards against division by zero when there are no positives.
    epsilon = 1e-7

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)

    if is_training:
        # requires_grad may only be changed on leaf tensors; a result that
        # already has a grad_fn is part of the graph and must not be touched.
        if f1.grad_fn is None:
            f1.requires_grad_(True)
    else:
        f1 = f1.detach()
    return f1



def get_acc(output, label):
    """Return the fraction of rows whose arg-max along dim 1 equals ``label``.

    ``output`` is indexed as (rows, scores) by ``max(1)``; the result is a
    Python float computed as matches / ``output.shape[0]``.
    """
    predicted = output.max(1)[1]
    hits = (predicted == label).sum().item()
    return hits / float(output.shape[0])

def set_learning_rate(optimizer, lr):
    """Overwrite the ``'lr'`` entry of every parameter group of ``optimizer``."""
    groups = optimizer.param_groups
    for group in groups:
        group['lr'] = lr

def set_seed(seed = 666):
    """Seed Python's ``random``, NumPy's global RNG and torch's RNG.

    When CUDA is available, the CUDA RNGs are seeded too and
    ``cudnn.deterministic`` is switched on.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True

def split_trainval_datasets(ops):
    """Split the class folders under ``ops.train_path`` into train/val lists.

    Every sub-directory of ``ops.train_path`` is one class. Directories are
    ordered by the integer before the first '-' in their name (e.g. ``0-a``,
    ``1-b``, ``10-c``), and the position in that order is the class label.
    Within each class the ``.jpg`` files are shuffled with the ``random``
    module and the first ``int(n_jpg * ops.val_factor)`` go to validation.

    Parameters
    ----------
    ops : object
        Must provide ``train_path`` (str) and ``val_factor`` (float).

    Returns
    -------
    tuple of list
        ``(train_paths, train_labels, val_paths, val_labels)``.
    """
    print(' --------------->>> split_trainval_datasets ')
    train_split_datasets = []
    train_split_datasets_label = []

    val_split_datasets = []
    val_split_datasets_label = []

    # Only directories are classes; stray files (e.g. .DS_Store) would
    # otherwise break the int() sort key or os.listdir below.
    class_dirs = [d for d in os.listdir(ops.train_path)
                  if os.path.isdir(os.path.join(ops.train_path, d))]
    for idx, doc in enumerate(sorted(class_dirs, key=lambda x: int(x.split('-')[0]))):
        class_dir = os.path.join(ops.train_path, doc)
        # Filter first so val_factor applies to images only; sort before
        # shuffling so a fixed seed gives the same split regardless of the
        # filesystem's (arbitrary) listing order.
        images = sorted(f for f in os.listdir(class_dir) if f.endswith('.jpg'))
        random.shuffle(images)

        cal_split_num = int(len(images) * ops.val_factor)
        val_files = images[:cal_split_num]
        train_files = images[cal_split_num:]

        val_split_datasets.extend(os.path.join(class_dir, f) for f in val_files)
        val_split_datasets_label.extend([idx] * len(val_files))
        train_split_datasets.extend(os.path.join(class_dir, f) for f in train_files)
        train_split_datasets_label.extend([idx] * len(train_files))

    print('\n')
    print('train_split_datasets len {}'.format(len(train_split_datasets)))
    print('val_split_datasets len {}'.format(len(val_split_datasets)))

    return train_split_datasets,train_split_datasets_label,val_split_datasets,val_split_datasets_label


def split_test_datasets(ops):
    """Collect the ``.jpg`` test images under ``ops.test_path`` with labels.

    Every sub-directory of ``ops.test_path`` is one class. Directories are
    ordered by the integer before the first '-' in their name, and the
    position in that order is the class label. Within each class the image
    order is shuffled with the ``random`` module.

    Parameters
    ----------
    ops : object
        Must provide ``test_path`` (str).

    Returns
    -------
    tuple of list
        ``(test_paths, test_labels)``.
    """
    print(' --------------->>> getting_test_datasets ')
    test_split_datasets = []
    test_split_datasets_label = []
    # Only directories are classes; stray files would otherwise break the
    # int() sort key or os.listdir below.
    class_dirs = [d for d in os.listdir(ops.test_path)
                  if os.path.isdir(os.path.join(ops.test_path, d))]
    for idx, doc in enumerate(sorted(class_dirs, key=lambda x: int(x.split('-')[0]))):
        class_dir = os.path.join(ops.test_path, doc)
        # Sort before shuffling so the order is reproducible under a fixed
        # seed regardless of the filesystem's listing order.
        images = sorted(f for f in os.listdir(class_dir) if f.endswith('.jpg'))
        random.shuffle(images)
        test_split_datasets.extend(os.path.join(class_dir, f) for f in images)
        test_split_datasets_label.extend([idx] * len(images))
    print('\n')
    print('test_split_datasets len {}'.format(len(test_split_datasets)))

    return test_split_datasets,test_split_datasets_label
